package com.jstarcraft.ai.jsat.regression;

import java.util.Arrays;

import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.linear.CholeskyDecomposition;
import com.jstarcraft.ai.jsat.linear.DenseMatrix;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Matrix;
import com.jstarcraft.ai.jsat.linear.SingularValueDecomposition;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;

/**
 * An implementation of Ridge Regression that finds the exact solution. Ridge
 * Regression is equivalent to {@link MultipleLinearRegression} with an added
 * L<sub>2</sub> penalty for the weight vector. <br>
 * <br>
 * Two different methods of finding the solution can be used, see
 * {@link SolverMode}. This algorithm should be used only for small-dimension
 * problems with a reasonable number of example points.<br>
 * For large-dimension sparse problems, or dense problems with many data points
 * (or both), use the {@link StochasticRidgeRegression}. For small data sets
 * that pose non-linear problems, you can also use
 * {@link KernelRidgeRegression}.
 * 
 * @author Edward Raff
 */
public class RidgeRegression implements Regressor, Parameterized {

    private static final long serialVersionUID = -4605757038780391895L;
    /** The L<sub>2</sub> regularization constant, always positive and finite. */
    private double lambda;
    /** The learned weights (without the bias), {@code null} until trained. */
    private Vec w;
    /** The learned intercept term. */
    private double bias;
    /** The solver used by {@link #train(RegressionDataSet, boolean)}. */
    private SolverMode mode;

    /**
     * Sets which solver to use
     */
    public enum SolverMode {
        /**
         * Solves by {@link CholeskyDecomposition}
         */
        EXACT_CHOLESKY,
        /**
         * Solves by {@link SingularValueDecomposition}
         */
        EXACT_SVD,
    }

    /**
     * Creates a new Ridge Regression learner with a regularization constant of
     * 1e-2 that solves by Cholesky decomposition.
     */
    public RidgeRegression() {
        this(1e-2);
    }

    /**
     * Creates a new Ridge Regression learner that solves by Cholesky
     * decomposition.
     * 
     * @param regularization the positive regularization constant in (0, Inf)
     */
    public RidgeRegression(double regularization) {
        this(regularization, SolverMode.EXACT_CHOLESKY);
    }

    /**
     * Creates a new Ridge Regression learner.
     * 
     * @param regularization the positive regularization constant in (0, Inf)
     * @param mode           the solver mode to use
     */
    public RidgeRegression(double regularization, SolverMode mode) {
        setLambda(regularization);
        setSolverMode(mode);
    }

    /**
     * Sets the regularization parameter used.
     * 
     * @param lambda the positive regularization constant in (0, Inf)
     * @throws IllegalArgumentException if lambda is NaN, infinite, or not
     *                                  positive
     */
    public void setLambda(double lambda) {
        if (Double.isNaN(lambda) || Double.isInfinite(lambda) || lambda <= 0) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the regularization constant in use
     * 
     * @return the regularization constant in use
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * Sets which solver is to be used. Any value other than
     * {@link SolverMode#EXACT_SVD} results in the Cholesky solver being used
     * during training.
     * 
     * @param mode the solver mode to use
     */
    public void setSolverMode(SolverMode mode) {
        this.mode = mode;
    }

    /**
     * Returns the solver in use
     * 
     * @return the solver to use
     */
    public SolverMode getSolverMode() {
        return mode;
    }

    /**
     * Predicts the target value as the dot product of the learned weights with
     * the numerical values of the data point, plus the learned bias.
     * 
     * @param data the data point to predict a value for
     * @return the predicted value
     * @throws NullPointerException if the model has not been trained yet
     */
    @Override
    public double regress(DataPoint data) {
        Vec x = data.getNumericalValues();

        return w.dot(x) + bias;
    }

    /**
     * Fits the weights and bias to the given data set using the configured
     * solver. A column of ones is prepended to the data so that the bias is
     * solved for together with the weights; as a consequence the bias term is
     * penalized by lambda just like every other coefficient.
     * 
     * @param dataSet  the data set to learn from
     * @param parallel {@code true} to use the shared thread pool for the
     *                 Cholesky solver, {@code false} to run serially. The SVD
     *                 solver ignores this flag.
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        final int dim = dataSet.getNumNumericalVars() + 1;
        final int n = dataSet.size();
        DenseMatrix X = new DenseMatrix(n, dim);

        for (int i = 0; i < n; i++) {
            Vec from = dataSet.getDataPoint(i).getNumericalValues();
            // column 0 is the constant feature that yields the bias term
            X.set(i, 0, 1.0);
            for (int j = 0; j < from.length(); j++) {
                X.set(i, j + 1, from.get(j));
            }
        }

        final Vec Y = dataSet.getTargetValues();
        final boolean serial = !parallel;

        if (mode == SolverMode.EXACT_SVD) {
            SingularValueDecomposition svd = new SingularValueDecomposition(X);
            double[] ridgeD;
            ridgeD = Arrays.copyOf(svd.getSingularValues(), dim);
            for (int i = 0; i < ridgeD.length; i++) {
                ridgeD[i] = 1 / (Math.pow(ridgeD[i], 2) + lambda);
            }
            Matrix U = svd.getU();
            Matrix V = svd.getV();

            // w = V (D^2 + lambda I)^(-1) D U^T y
            Matrix.diagMult(V, DenseVector.toDenseVec(ridgeD));
            Matrix.diagMult(V, DenseVector.toDenseVec(svd.getSingularValues()));
            w = V.multiply(U.transpose()).multiply(Y);
        } else { // cholesky
            Matrix H = serial ? X.transposeMultiply(X) : X.transposeMultiply(X, ParallelUtils.CACHED_THREAD_POOL);
            // H + I * lambda, done in place on the diagonal instead of
            // materializing a scaled identity matrix
            for (int i = 0; i < H.rows(); i++) {
                H.increment(i, i, lambda);
            }
            CholeskyDecomposition cd = serial ? new CholeskyDecomposition(H) : new CholeskyDecomposition(H, ParallelUtils.CACHED_THREAD_POOL);
            // w = (X^T X + lambda I)^(-1) X^T y
            w = cd.solve(Matrix.eye(H.rows())).multiply(X.transpose()).multiply(Y);
        }

        // reformat w and separate out bias term
        bias = w.get(0);
        Vec newW = new DenseVector(w.length() - 1);
        for (int i = 0; i < newW.length(); i++) {
            newW.set(i, w.get(i + 1));
        }
        w = newW;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this regressor with the same regularization constant,
     * solver mode, and (if trained) learned weights and bias.
     * 
     * @return an independent copy of this object
     */
    @Override
    public RidgeRegression clone() {
        // pass the solver mode along so the copy does not fall back to the
        // default Cholesky solver
        RidgeRegression clone = new RidgeRegression(lambda, mode);
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        clone.bias = this.bias;
        return clone;
    }
}
